{
 "cells": [
  {
   "cell_type": "markdown",
   "source": [
    " \n",
    "# Federated Learning Training Plan: Execute Plan\n",
    "\n",
    "Here we load and execute Plan and Model params created earlier in \"Create Plan\" notebook. \n",
    "\n",
    "This represents PySyft (python) worker."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "is_executing": false
    }
   },
   "outputs": [
    {
     "name": "stderr",
     "text": [
      "C:\\Users\\Vova\\AppData\\Local\\conda\\conda\\envs\\pysyft\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:526: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint8 = np.dtype([(\"qint8\", np.int8, 1)])\n",
      "C:\\Users\\Vova\\AppData\\Local\\conda\\conda\\envs\\pysyft\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:527: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint8 = np.dtype([(\"quint8\", np.uint8, 1)])\n",
      "C:\\Users\\Vova\\AppData\\Local\\conda\\conda\\envs\\pysyft\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:528: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint16 = np.dtype([(\"qint16\", np.int16, 1)])\n",
      "C:\\Users\\Vova\\AppData\\Local\\conda\\conda\\envs\\pysyft\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:529: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_quint16 = np.dtype([(\"quint16\", np.uint16, 1)])\n",
      "C:\\Users\\Vova\\AppData\\Local\\conda\\conda\\envs\\pysyft\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:530: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  _np_qint32 = np.dtype([(\"qint32\", np.int32, 1)])\n",
      "C:\\Users\\Vova\\AppData\\Local\\conda\\conda\\envs\\pysyft\\lib\\site-packages\\tensorflow\\python\\framework\\dtypes.py:535: FutureWarning: Passing (type, 1) or '1type' as a synonym of type is deprecated; in a future version of numpy, it will be understood as (type, (1,)) / '(1,)type'.\n",
      "  np_resource = np.dtype([(\"resource\", np.ubyte, 1)])\n",
      "Falling back to insecure randomness since the required custom op could not be found for the installed version of TensorFlow. Fix this by compiling custom ops. Missing file was 'C:\\Users\\Vova\\AppData\\Local\\conda\\conda\\envs\\pysyft\\lib\\site-packages\\tf_encrypted/operations/secure_random/secure_random_module_tf_1.13.1.so'\n"
     ],
     "output_type": "stream"
    },
    {
     "name": "stdout",
     "text": [
      "Setting up Sandbox...\n",
      "Done!\n"
     ],
     "output_type": "stream"
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "import syft as sy\n",
    "import torch as th\n",
    "from torchvision import datasets, transforms\n",
    "from syft.serde import protobuf\n",
    "from syft_proto.execution.v1.plan_pb2 import Plan as PlanPB\n",
    "from syft_proto.execution.v1.state_pb2 import State as StatePB\n",
    "from syft_proto.types.torch.v1.script_function_pb2 import ScriptFunction as ScriptFunctionPB\n",
    "from syft import PlaceHolder\n",
    "from syft.execution.state import State\n",
    "import os\n",
    "import numpy as np\n",
    "\n",
    "sy.hook(globals())\n",
    "# force protobuf serialization for tensors\n",
    "hook.local_worker.framework = None"
   ]
  },
  {
   "cell_type": "markdown",
   "source": [
    "Utility func that unserializes file contents into PySyft classes.\n",
    "Note that we must know file contents beforehand to use specific protobuf class for deserialization."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "def deserializeFromBin(worker, filename, pb):\n",
    "    with open(filename, \"rb\") as f:\n",
    "        bin = f.read()\n",
    "    pb.ParseFromString(bin)\n",
    "    return protobuf.serde._unbufferize(worker, pb)\n",
    "\n",
    "def serializeToBinPb(worker, obj, filename):\n",
    "    pb = protobuf.serde._bufferize(worker, obj)\n",
    "    bin = pb.SerializeToString()\n",
    "    print(\"Writing %s to %s/%s\" % (obj.__class__.__name__, os.getcwd(), filename))\n",
    "    with open(filename, \"wb\") as f:\n",
    "        f.write(bin)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 5: Unserialize "
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "outputs": [
    {
     "name": "stdout",
     "text": [
      "Loaded plan (# of actions): 39\n",
      "Loaded tracescript plan code: def forward(self,\n",
      "    argument_1: Tensor,\n",
      "    argument_2: Tensor,\n",
      "    argument_3: Tensor,\n",
      "    argument_4: Tensor,\n",
      "    argument_5: Tensor,\n",
      "    argument_6: Tensor,\n",
      "    argument_7: Tensor,\n",
      "    argument_8: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:\n",
      "  _0 = torch.matmul(argument_1, torch.t(argument_5))\n",
      "  _1 = torch.add(_0, argument_6, alpha=1)\n",
      "  _2 = torch.relu(_1)\n",
      "  _3 = torch.add(torch.matmul(_2, torch.t(argument_7)), argument_8, alpha=1)\n",
      "  _4 = torch.softmax(_3, 1, None)\n",
      "  _5 = torch.mean(torch.mul(argument_2, torch.log(_4)), dtype=None)\n",
      "  _6 = torch.neg(_5)\n",
      "  _7 = torch.div(torch.sub(_4, argument_2, alpha=1), torch.mul(argument_3, CONSTANTS.c0))\n",
      "  _8 = torch.matmul(_7, argument_7)\n",
      "  _9 = torch.to(torch.gt(_1, 0), 6, False, False, None)\n",
      "  _10 = torch.mul(_8, _9)\n",
      "  _11 = torch.matmul(torch.t(_10), argument_1)\n",
      "  _12 = torch.sum(_10, [0], False, dtype=None)\n",
      "  _13 = torch.matmul(torch.t(_7), _2)\n",
      "  _14 = torch.sum(_7, [0], False, dtype=None)\n",
      "  _15 = torch.sub(argument_5, torch.mul(_11, argument_4), alpha=1)\n",
      "  _16 = torch.sub(argument_6, torch.mul(_12, argument_4), alpha=1)\n",
      "  _17 = torch.sub(argument_7, torch.mul(_13, argument_4), alpha=1)\n",
      "  _18 = torch.sub(argument_8, torch.mul(_14, argument_4), alpha=1)\n",
      "  _19 = torch.eq(torch.argmax(_4, 1, False), torch.argmax(argument_2, 1, False))\n",
      "  _20 = torch.sum(torch.to(_19, 6, False, False, None), dtype=None)\n",
      "  _21 = (_6, torch.div(_20, argument_3), _15, _16, _17, _18)\n",
      "  return _21\n",
      "\n",
      "Loaded params count: 4\n"
     ],
     "output_type": "stream"
    }
   ],
   "source": [
    "training_plan_ops = deserializeFromBin(hook.local_worker, \"tp_ops.pb\", PlanPB())\n",
    "training_plan_ts = deserializeFromBin(hook.local_worker, \"tp_ts.pb\", ScriptFunctionPB())\n",
    "model_params_state = deserializeFromBin(hook.local_worker, \"model_params.pb\", StatePB())\n",
    "# unwrap tensors from State\n",
    "model_params = model_params_state.tensors()\n",
    "\n",
    "print(\"Loaded plan (# of actions):\", len(training_plan_ops.actions))\n",
    "print(\"Loaded tracescript plan code:\", training_plan_ts.code)\n",
    "print(\"Loaded params count:\", len(model_params))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 6: Train!\n",
    "\n",
    "Define the full training procedure that uses Plan as one training step on a batch of data. "
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [],
   "source": [
    "batch_size = 64\n",
    "mnist = th.utils.data.DataLoader(\n",
    "    datasets.MNIST('data', train=True, download=True, transform=transforms.ToTensor()),\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True\n",
    ")\n",
    "\n",
    "def execute_training_plan(data, plan, model_params, epochs=1, batch_size=th.tensor(batch_size), lr=th.tensor(0.01)):\n",
    "    for epoch in range(1, epochs+1):\n",
    "        losses = []\n",
    "        accuracies = []\n",
    "        for batch_idx, (X, y) in enumerate(data):\n",
    "            X = X.view(X.shape[0], -1)\n",
    "            y_oh = th.nn.functional.one_hot(y, 10)\n",
    "            loss, acc, *model_params = plan(X, y_oh, batch_size, lr, *model_params)\n",
    "            losses.append(loss.item())\n",
    "            accuracies.append(acc.item())\n",
    "            if batch_idx % 100 == 0:\n",
    "                print(\"Batch %d, loss: %f, accuracy: %f\" % (batch_idx, loss, acc), end=\"\\r\")\n",
    "        print('Epoch %d, avg loss: %f, avg training accuracy: %f' % (epoch, np.mean(losses), np.mean(accuracies)))\n",
    "    return model_params"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "To show both plans work, first we run training with plain ops Plan and get updated model weights,\n",
    "then execute training with torchscript'ed Plan starting with updated weights and get further updated weights :)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "outputs": [
    {
     "name": "stdout",
     "text": [
      "Epoch 1, avg loss: 0.207544, avg training accuracy: 0.441848\n"
     ],
     "output_type": "stream"
    }
   ],
   "source": [
    "# Plain Plan\n",
    "updated_model_params = execute_training_plan(mnist, training_plan_ops, model_params)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% \n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "outputs": [
    {
     "name": "stdout",
     "text": [
      "Epoch 1, avg loss: 0.157933, avg training accuracy: 0.731726\n"
     ],
     "output_type": "stream"
    }
   ],
   "source": [
    "# Torchscript Plan\n",
    "updated_model_params = execute_training_plan(mnist, training_plan_ts, updated_model_params)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "## Step 7: Create Diff\n",
    "\n",
    "Naive diff is just a difference between original model weights and updated model weights. "
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "outputs": [
    {
     "name": "stdout",
     "text": [
      "[torch.Size([392, 784]), torch.Size([392]), torch.Size([10, 392]), torch.Size([10])]\n"
     ],
     "output_type": "stream"
    }
   ],
   "source": [
    "diff = [ model_params[i] - updated_model_params[i] for i in range(len(model_params)) ]\n",
    "\n",
    "print([ item.shape for item in diff ])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Let's wrap it in State to serialize."
   ],
   "metadata": {
    "collapsed": false
   }
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "outputs": [
    {
     "name": "stdout",
     "text": [
      "Writing State to e:\\ml/diff.pb\n"
     ],
     "output_type": "stream"
    }
   ],
   "source": [
    "diff_state = State(\n",
    "    owner=hook.local_worker,\n",
    "    state_placeholders=[PlaceHolder().instantiate(param) for param in diff]\n",
    ")\n",
    "serializeToBinPb(hook.local_worker, diff_state, 'diff.pb')"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n",
     "is_executing": false
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  },
  "pycharm": {
   "stem_cell": {
    "cell_type": "raw",
    "source": [],
    "metadata": {
     "collapsed": false
    }
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}